{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import  os\n",
    "\n",
    "# Set the C++ log level BEFORE TensorFlow is imported; assigning it afterwards\n",
    "# cannot silence messages that are emitted while the library is loading.\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "import  numpy as np\n",
    "import  tensorflow as tf\n",
    "from    tensorflow import keras\n",
    "from    tensorflow.keras import layers\n",
    "\n",
    "\n",
    "# Seed the TensorFlow and NumPy random generators.\n",
    "tf.random.set_seed(22)\n",
    "np.random.seed(22)\n",
    "# The rest of the notebook is written against the TensorFlow 2.x API.\n",
    "assert tf.__version__.startswith('2.')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: (25000, 80) tf.Tensor(1, shape=(), dtype=int64) tf.Tensor(0, shape=(), dtype=int64)\n",
      "x_test shape: (25000, 80)\n"
     ]
    }
   ],
   "source": [
    "batchsz = 128\n",
    "\n",
    "# vocabulary size: keep only the most frequent words\n",
    "total_words = 10000\n",
    "# every review is padded / truncated to this many tokens\n",
    "max_review_len = 80\n",
    "# length of each word-embedding vector\n",
    "embedding_len = 100\n",
    "(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=total_words)\n",
    "# x_train:[b, 80]\n",
    "# x_test: [b, 80]\n",
    "x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len)\n",
    "x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len)\n",
    "\n",
    "db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
    "db_train = db_train.shuffle(1000).batch(batchsz, drop_remainder=True)\n",
    "db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
    "# Keep the final partial batch: dropping it would silently exclude\n",
    "# len(x_test) % batchsz reviews from the validation / evaluation metrics.\n",
    "db_test = db_test.batch(batchsz)\n",
    "print('x_train shape:', x_train.shape, tf.reduce_max(y_train), tf.reduce_min(y_train))\n",
    "print('x_test shape:', x_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/4\n",
      "195/195 [==============================] - 74s 381ms/step - loss: 0.6770 - accuracy: 0.5534 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00\n",
      "Epoch 2/4\n",
      "195/195 [==============================] - 51s 259ms/step - loss: 0.4345 - accuracy: 0.8041 - val_loss: 0.3883 - val_accuracy: 0.8292\n",
      "Epoch 3/4\n",
      "195/195 [==============================] - 50s 256ms/step - loss: 0.3131 - accuracy: 0.8739 - val_loss: 0.4266 - val_accuracy: 0.8326\n",
      "Epoch 4/4\n",
      "195/195 [==============================] - 47s 239ms/step - loss: 0.2670 - accuracy: 0.8955 - val_loss: 0.7790 - val_accuracy: 0.6079\n",
      "195/195 [==============================] - 14s 73ms/step - loss: 0.7790 - accuracy: 0.6079\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "class MyRNN(keras.Model):\n",
    "    \"\"\"\n",
    "    Sentiment classifier: embedding => two stacked SimpleRNN layers => dense + sigmoid.\n",
    "\n",
    "    Relies on the notebook-level constants total_words, embedding_len and\n",
    "    max_review_len to size the embedding layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, units):\n",
    "        \"\"\"\n",
    "        :param units: output size of each SimpleRNN layer\n",
    "        \"\"\"\n",
    "        super(MyRNN, self).__init__()\n",
    "\n",
    "        # transform text to embedding representation\n",
    "        # [b, 80] => [b, 80, 100]\n",
    "        self.embedding = layers.Embedding(total_words, embedding_len,\n",
    "                                          input_length=max_review_len)\n",
    "\n",
    "        # [b, 80, 100] => [b, 80, units] => [b, units]\n",
    "        # the first layer returns the whole sequence so the second one can consume it\n",
    "        self.rnn = keras.Sequential([\n",
    "            layers.SimpleRNN(units, dropout=0.5, return_sequences=True, unroll=True),\n",
    "            layers.SimpleRNN(units, dropout=0.5, unroll=True)\n",
    "        ])\n",
    "\n",
    "        # fc, [b, units] => [b, 1]\n",
    "        self.outlayer = layers.Dense(1)\n",
    "\n",
    "    def call(self, inputs, training=None):\n",
    "        \"\"\"\n",
    "        Forward pass.\n",
    "        net(x, training=True): train mode (dropout active)\n",
    "        net(x, training=False): test mode (dropout disabled)\n",
    "        :param inputs: [b, 80] padded word indices\n",
    "        :param training: forwarded to the RNN stack; None leaves the mode to Keras\n",
    "        :return: [b, 1] probability that each review is positive\n",
    "        \"\"\"\n",
    "        # embedding: [b, 80] => [b, 80, 100]\n",
    "        x = self.embedding(inputs)\n",
    "        # rnn: [b, 80, 100] => [b, units]\n",
    "        # pass the flag explicitly so the dropout mode never depends on implicit propagation\n",
    "        x = self.rnn(x, training=training)\n",
    "\n",
    "        # out: [b, units] => [b, 1]\n",
    "        x = self.outlayer(x)\n",
    "        # p(y is pos|x)\n",
    "        prob = tf.sigmoid(x)\n",
    "\n",
    "        return prob\n",
    "\n",
    "def main(units=64, epochs=4, learning_rate=0.001):\n",
    "    \"\"\"\n",
    "    Build a MyRNN model, train it on db_train and evaluate it on db_test.\n",
    "\n",
    "    :param units: output size of each SimpleRNN layer\n",
    "    :param epochs: number of passes over db_train\n",
    "    :param learning_rate: learning rate handed to the Adam optimizer\n",
    "    \"\"\"\n",
    "    model = MyRNN(units)\n",
    "    # the model outputs sigmoid probabilities, which matches the default\n",
    "    # BinaryCrossentropy() configuration\n",
    "    model.compile(optimizer=keras.optimizers.Adam(learning_rate),\n",
    "                  loss=tf.losses.BinaryCrossentropy(),\n",
    "                  metrics=['accuracy'])\n",
    "    model.fit(db_train, epochs=epochs, validation_data=db_test)\n",
    "\n",
    "    model.evaluate(db_test)\n",
    "\n",
    "\n",
    "# Running this cell satisfies the guard below, so training starts right away.\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
